// Copyright 2016 The SwiftShader Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "ValidateLimitations.h"
#include "InfoSink.h"
#include "InitializeParseContext.h"
#include "ParseHelper.h"

namespace {
bool IsLoopIndex(const TIntermSymbol* symbol, const TLoopStack& stack) {
	for (TLoopStack::const_iterator i = stack.begin(); i != stack.end(); ++i) {
		if (i->index.id == symbol->getId())
			return true;
	}
	return false;
}

// Finds the loop on `stack` whose index is `symbol` and sets its unroll flag.
// Callers are expected to pass only symbols already accepted by IsLoopIndex();
// reaching the end of the stack is therefore treated as unreachable.
void MarkLoopForUnroll(const TIntermSymbol* symbol, TLoopStack& stack) {
	TLoopStack::iterator i = stack.begin();
	while (i != stack.end() && i->index.id != symbol->getId())
		++i;

	if (i == stack.end()) {
		UNREACHABLE(0);
		return;
	}

	ASSERT(i->loop);
	i->loop->setUnrollFlag(true);
}

// Traverses an expression to decide whether it is a constant-index-expression
// as defined in GLSL ES 1.0, Appendix A, section 4. Such expressions are a
// superset of constant expressions and may be built from:
// - constant expressions,
// - loop indices as defined in section 4,
// - any combination of the two.
// Every symbol reached during traversal must therefore be either a constant
// (EvqConstExpr) or an index of a loop on the supplied loop stack.
class ValidateConstIndexExpr : public TIntermTraverser {
public:
	ValidateConstIndexExpr(const TLoopStack& stack)
		: mValid(true), mLoopStack(stack) {}

	// True until a disallowed symbol has been visited.
	bool isValid() const { return mValid; }

	virtual void visitSymbol(TIntermSymbol* symbol) {
		// Once a violation has been seen the verdict cannot change.
		if (!mValid)
			return;

		const bool isConstant = (symbol->getQualifier() == EvqConstExpr);
		mValid = isConstant || IsLoopIndex(symbol, mLoopStack);
	}

private:
	bool mValid;                   // cleared by the first non-constant, non-index symbol
	const TLoopStack& mLoopStack;  // loops enclosing the expression (referenced, not copied)
};

// Traverses a node looking for uses of indices of loops on the loop stack.
// Records whether a float and/or an integer (int or uint) loop index is used,
// and flags every loop whose integer index is referenced for unrolling.
// NOTE(review): the original design applies this to sampler-array index
// expressions; no caller is visible in this file — confirm at the call site.
class ValidateLoopIndexExpr : public TIntermTraverser {
public:
	ValidateLoopIndexExpr(TLoopStack& stack)
		: mUsesFloatLoopIndex(false),
		  mUsesIntLoopIndex(false),
		  mLoopStack(stack) {}

	bool usesFloatLoopIndex() const { return mUsesFloatLoopIndex; }
	bool usesIntLoopIndex() const { return mUsesIntLoopIndex; }

	virtual void visitSymbol(TIntermSymbol* symbol) {
		// Symbols that are not loop indices are irrelevant here.
		if (!IsLoopIndex(symbol, mLoopStack))
			return;

		switch (symbol->getBasicType()) {
		case EbtFloat:
			mUsesFloatLoopIndex = true;
			break;
		case EbtInt:
		case EbtUInt:
			// Integer indices require the enclosing loop to be unrolled.
			mUsesIntLoopIndex = true;
			MarkLoopForUnroll(symbol, mLoopStack);
			break;
		default:
			// validateForLoopInit only admits integer or float indices.
			UNREACHABLE(symbol->getBasicType());
		}
	}

private:
	bool mUsesFloatLoopIndex;
	bool mUsesIntLoopIndex;
	TLoopStack& mLoopStack;  // mutated through MarkLoopForUnroll
};
}  // namespace

// Creates a validator for a shader of type `shaderType`; diagnostics are
// written to `sink` and counted in mNumErrors, which starts at zero.
ValidateLimitations::ValidateLimitations(GLenum shaderType, TInfoSinkBase& sink)
	: mShaderType(shaderType), mSink(sink), mNumErrors(0)
{
}

// Checks a binary node for two things: a state-modifying operation that
// targets a loop index, and, for array indexing, a valid index expression.
// Always returns true so that the children are traversed as well.
bool ValidateLimitations::visitBinary(Visit, TIntermBinary* node)
{
	validateOperation(node, node->getLeft());

	// Direct and indirect indexing are subject to the same rules.
	const TOperator op = node->getOp();
	if (op == EOpIndexDirect || op == EOpIndexIndirect)
		validateIndexing(node);

	return true;
}

// Rejects state-modifying unary operations whose operand is a loop index.
// Always returns true so that the operand is traversed as well.
bool ValidateLimitations::visitUnary(Visit, TIntermUnary* node)
{
	TIntermNode* operand = node->getOperand();
	validateOperation(node, operand);
	return true;
}

// Function calls are checked so that loop indices are never passed to
// out/inout parameters; all other aggregates need no check here.
// Always returns true so that the children are traversed as well.
bool ValidateLimitations::visitAggregate(Visit, TIntermAggregate* node)
{
	if (node->getOp() == EOpFunctionCall)
		validateFunctionCall(node);
	return true;
}

// Validates a loop statement.
// - Rejects while and do-while loops.
// - Validates the for-loop header.
// - Traverses the body with this loop's info pushed on mLoopStack, so that
//   uses of the loop index inside the body can be recognized.
// Returns false in every case so that the generic traversal does not visit
// the children a second time.
bool ValidateLimitations::visitLoop(Visit, TIntermLoop* node)
{
	if (!validateLoopType(node))
		return false;

	// Value-initialize instead of memset: this zeroes the members when
	// TLoopInfo is a plain aggregate (presumably index.id and loop pointer —
	// confirm in ValidateLimitations.h). It also keeps working if the type
	// ever gains non-trivial members, and needs no <cstring>.
	TLoopInfo info{};
	info.loop = node;
	if (!validateForLoopHeader(node, &info))
		return false;

	TIntermNode* body = node->getBody();
	if (body) {
		// The index is only "live" while its body is being traversed.
		mLoopStack.push_back(info);
		body->traverse(this);
		mLoopStack.pop_back();
	}

	// The loop is fully processed - no need to visit children.
	return false;
}

// Writes an error-prefixed diagnostic at `loc` of the form
// "'token' : reason" to the info sink, and counts it in mNumErrors.
void ValidateLimitations::error(TSourceLoc loc,
                                const char *reason, const char* token)
{
	++mNumErrors;

	mSink.prefix(EPrefixError);
	mSink.location(loc);
	mSink << "'" << token;
	mSink << "' : " << reason << "\n";
}

// True while visitLoop is traversing at least one loop body.
bool ValidateLimitations::withinLoopBody() const
{
	return mLoopStack.size() != 0;
}

bool ValidateLimitations::isLoopIndex(const TIntermSymbol* symbol) const
{
	return IsLoopIndex(symbol, mLoopStack);
}

// Only for-loops are allowed. While and do-while loops are reported (naming
// the offending keyword) and rejected.
bool ValidateLimitations::validateLoopType(TIntermLoop* node) {
	const TLoopType type = node->getType();
	if (type != ELoopFor) {
		const char* keyword = (type == ELoopWhile) ? "while" : "do";
		error(node->getLine(), "This type of loop is not allowed", keyword);
		return false;
	}
	return true;
}

// Validates the three clauses of
//    for ( init-declaration ; condition ; expression ) statement
// in source order and stops at the first failing clause. The order matters:
// the init check records info->index, which the other two compare against.
bool ValidateLimitations::validateForLoopHeader(TIntermLoop* node,
                                                TLoopInfo* info)
{
	ASSERT(node->getType() == ELoopFor);

	return validateForLoopInit(node, info) &&
	       validateForLoopCond(node, info) &&
	       validateForLoopExpr(node, info);
}

// Validates the init clause, which must have the form
//     type-specifier identifier = constant-expression
// In the AST this is an EOpDeclaration aggregate holding exactly one
// EOpInitialize binary node whose left side is the index symbol.
// On success, records the index symbol id in info->index.id.
bool ValidateLimitations::validateForLoopInit(TIntermLoop* node,
                                              TLoopInfo* info)
{
	TIntermNode* init = node->getInit();
	if (init == nullptr) {
		error(node->getLine(), "Missing init declaration", "for");
		return false;
	}

	const char* const invalidDecl = "Invalid init declaration";

	TIntermAggregate* decl = init->getAsAggregate();
	if (decl == nullptr || decl->getOp() != EOpDeclaration) {
		error(init->getLine(), invalidDecl, "for");
		return false;
	}

	// Declaration lists (e.g. "int i = 0, j = 0") are rejected for simplicity.
	TIntermSequence& declarators = decl->getSequence();
	if (declarators.size() != 1) {
		error(decl->getLine(), invalidDecl, "for");
		return false;
	}

	TIntermBinary* initializer = declarators[0]->getAsBinaryNode();
	if (initializer == nullptr || initializer->getOp() != EOpInitialize) {
		error(decl->getLine(), invalidDecl, "for");
		return false;
	}

	TIntermSymbol* indexSymbol = initializer->getLeft()->getAsSymbolNode();
	if (indexSymbol == nullptr) {
		error(initializer->getLine(), invalidDecl, "for");
		return false;
	}

	// The index must have an integer or float type...
	const TBasicType indexType = indexSymbol->getBasicType();
	const bool numericIndex = IsInteger(indexType) || (indexType == EbtFloat);
	if (!numericIndex) {
		error(indexSymbol->getLine(),
		      "Invalid type for loop index", getBasicString(indexType));
		return false;
	}

	// ...and be initialized with a constant expression.
	if (!isConstExpr(initializer->getRight())) {
		error(initializer->getLine(),
		      "Loop index cannot be initialized with non-constant expression",
		      indexSymbol->getSymbol().c_str());
		return false;
	}

	info->index.id = indexSymbol->getId();
	return true;
}

// Validates the condition clause, which must have the form
//     loop_index relational_operator constant_expression
// where loop_index is the symbol recorded in info->index by the init clause,
// and relational_operator is one of > >= < <= == !=.
// Returns false (after reporting an error) on the first violation.
bool ValidateLimitations::validateForLoopCond(TIntermLoop* node,
                                              TLoopInfo* info)
{
	TIntermNode* cond = node->getCondition();
	if (!cond) {
		error(node->getLine(), "Missing condition", "for");
		return false;
	}

	TIntermBinary* binOp = cond->getAsBinaryNode();
	if (!binOp) {
		error(node->getLine(), "Invalid condition", "for");
		return false;
	}
	// Loop index should be to the left of relational operator.
	TIntermSymbol* symbol = binOp->getLeft()->getAsSymbolNode();
	if (!symbol) {
		error(binOp->getLine(), "Invalid condition", "for");
		return false;
	}
	if (symbol->getId() != info->index.id) {
		error(symbol->getLine(),
			  "Expected loop index", symbol->getSymbol().c_str());
		return false;
	}
	// Relational operator is one of: > >= < <= == or !=.
	switch (binOp->getOp()) {
	case EOpEqual:
	case EOpNotEqual:
	case EOpLessThan:
	case EOpGreaterThan:
	case EOpLessThanEqual:
	case EOpGreaterThanEqual:
		break;
	default:
		// Reject the header like every other failed check here, and like the
		// operator check in validateForLoopExpr. Falling through would accept
		// a malformed condition and go on to validate the loop body.
		error(binOp->getLine(),
			  "Invalid relational operator",
			  getOperatorString(binOp->getOp()));
		return false;
	}
	// Loop index must be compared with a constant.
	if (!isConstExpr(binOp->getRight())) {
		error(binOp->getLine(),
			  "Loop index cannot be compared with non-constant expression",
			  symbol->getSymbol().c_str());
		return false;
	}

	return true;
}

// Validates the expression clause of a for-loop header. It must be one of:
//     loop_index++
//     loop_index--
//     loop_index += constant_expression
//     loop_index -= constant_expression
//     ++loop_index
//     --loop_index
// The last two forms are not listed in the spec, but they are accepted on the
// assumption that their omission is an oversight.
// loop_index must be the symbol recorded in info->index by the init clause.
// Returns false (after reporting an error) on the first violation.
bool ValidateLimitations::validateForLoopExpr(TIntermLoop* node,
                                              TLoopInfo* info)
{
	TIntermNode* expr = node->getExpression();
	if (!expr) {
		error(node->getLine(), "Missing expression", "for");
		return false;
	}

	// Exactly one of unOp / binOp is non-null when the node has a usable shape.
	TIntermUnary* unOp = expr->getAsUnaryNode();
	TIntermBinary* binOp = unOp ? nullptr : expr->getAsBinaryNode();

	TOperator op = EOpNull;
	TIntermSymbol* symbol = nullptr;
	if (unOp) {
		op = unOp->getOp();
		symbol = unOp->getOperand()->getAsSymbolNode();
	} else if (binOp) {
		op = binOp->getOp();
		symbol = binOp->getLeft()->getAsSymbolNode();
	}

	// The operand must be loop index.
	if (!symbol) {
		error(expr->getLine(), "Invalid expression", "for");
		return false;
	}
	if (symbol->getId() != info->index.id) {
		error(symbol->getLine(),
			  "Expected loop index", symbol->getSymbol().c_str());
		return false;
	}

	// The operator is one of: ++ -- += -=.
	switch (op) {
		case EOpPostIncrement:
		case EOpPostDecrement:
		case EOpPreIncrement:
		case EOpPreDecrement:
			ASSERT((unOp != nullptr) && (binOp == nullptr));
			break;
		case EOpAddAssign:
		case EOpSubAssign:
			ASSERT((unOp == nullptr) && (binOp != nullptr));
			break;
		default:
			error(expr->getLine(), "Invalid operator", getOperatorString(op));
			return false;
	}

	// Loop index must be incremented/decremented with a constant.
	if (binOp != nullptr && !isConstExpr(binOp->getRight())) {
		error(binOp->getLine(),
			  "Loop index cannot be modified by non-constant expression",
			  symbol->getSymbol().c_str());
		return false;
	}

	return true;
}

// Inside a loop body, a loop index must not be passed as the argument of an
// out or inout parameter, because the callee could then modify it.
// The callee is found in the global parse context's symbol table using the
// call node's name.
// Returns false if any violation was reported, true otherwise.
bool ValidateLimitations::validateFunctionCall(TIntermAggregate* node)
{
	ASSERT(node->getOp() == EOpFunctionCall);

	// If not within loop body, there is nothing to check.
	if (!withinLoopBody())
		return true;

	// Argument positions at which a bare loop-index symbol is passed.
	std::vector<int> loopIndexArgs;
	TIntermSequence& params = node->getSequence();
	for (TIntermSequence::size_type i = 0; i < params.size(); ++i) {
		TIntermSymbol* symbol = params[i]->getAsSymbolNode();
		if (symbol && isLoopIndex(symbol))
			loopIndexArgs.push_back(static_cast<int>(i));
	}
	// If none of the loop indices are used as arguments,
	// there is nothing to check.
	if (loopIndexArgs.empty())
		return true;

	auto* parseContext = GetGlobalParseContext();
	TSymbolTable& symbolTable = parseContext->symbolTable;
	TSymbol* symbol = symbolTable.find(node->getName(), parseContext->getShaderVersion());
	ASSERT(symbol && symbol->isFunction());
	// ASSERT may be compiled out in release builds. Without this guard, the
	// static_cast below would dereference a null or non-function symbol.
	if (!symbol || !symbol->isFunction())
		return true;
	TFunction* function = static_cast<TFunction*>(symbol);

	bool valid = true;
	for (int argIndex : loopIndexArgs) {
		const TParameter& param = function->getParam(argIndex);
		TQualifier qual = param.type->getQualifier();
		if ((qual == EvqOut) || (qual == EvqInOut)) {
			error(params[argIndex]->getLine(),
				  "Loop index cannot be used as argument to a function out or inout parameter",
				  params[argIndex]->getAsSymbolNode()->getSymbol().c_str());
			valid = false;
		}
	}

	return valid;
}

// Inside a loop body, reports an error if the state-modifying operator `node`
// has a loop index as its `operand`.
// Always returns true; violations are recorded only through error().
bool ValidateLimitations::validateOperation(TIntermOperator* node,
                                            TIntermNode* operand) {
	if (withinLoopBody() && node->modifiesState()) {
		const TIntermSymbol* target = operand->getAsSymbolNode();
		if (target != nullptr && isLoopIndex(target)) {
			error(node->getLine(),
			      "Loop index cannot be statically assigned to within the body of the loop",
			      target->getSymbol().c_str());
		}
	}
	return true;
}

// Only a constant-union node is accepted as a constant expression; any other
// node shape is treated as non-constant.
bool ValidateLimitations::isConstExpr(TIntermNode* node)
{
	ASSERT(node);
	const bool isConstantUnion = (node->getAsConstantUnion() != nullptr);
	return isConstantUnion;
}

// Returns true if every symbol in `node` is a constant or an index of a loop
// currently being traversed, i.e. `node` is a constant-index-expression.
bool ValidateLimitations::isConstIndexExpr(TIntermNode* node)
{
	ASSERT(node);

	ValidateConstIndexExpr checker(mLoopStack);
	node->traverse(&checker);
	const bool valid = checker.isValid();
	return valid;
}

// Validates an array-indexing node (direct or indirect). Both rules are always
// checked, so both may be reported. Returns false if either is violated.
bool ValidateLimitations::validateIndexing(TIntermBinary* node)
{
	ASSERT((node->getOp() == EOpIndexDirect) ||
	       (node->getOp() == EOpIndexIndirect));

	TIntermTyped* index = node->getRight();
	TIntermTyped* operand = node->getLeft();
	bool valid = true;

	// Rule 1: the index expression must have integral type.
	if (!index->isScalarInt()) {
		error(index->getLine(),
		      "Index expression must have integral type",
		      index->getCompleteString().c_str());
		valid = false;
	}

	// Rule 2: the index expression must be a constant-index-expression,
	// except when a vertex shader indexes a uniform.
	const bool uniformInVertexShader = (mShaderType == GL_VERTEX_SHADER) &&
	                                   (operand->getQualifier() == EvqUniform);
	if (!uniformInVertexShader && !isConstIndexExpr(index)) {
		error(index->getLine(), "Index expression must be constant", "[]");
		valid = false;
	}

	return valid;
}

